//===--- FunctionSignatureTransforms.swift ---------------------------------==//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2024 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import SIL

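/// Replaces an apply that has thick metatype arguments with an apply of a specialized function, in which the
/// metatype values are not passed but are rematerialized in the entry block of the specialized function.
/// Only metatypes of structs and enums are removed, because their values are statically known.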
///
/// ```
///   func caller() {
///     callee(Int.self)
///   }
///   func callee(_ t: Int.Type) {  // a thick metatype
///     // ...
///   }
/// ```
/// ->
/// ```
///   func caller() {
///     specialized_callee()
///   }
///   func specialized_callee() {
///     let t: Int.Type = Int.self
///     // ...
///   }
///   // remains a thunk
///   func callee(_ t: Int.Type) {
///      specialized_callee()
///   }
/// ```
///
func specializeByRemovingMetatypeArguments(apply: FullApplySite, _ context: ModulePassContext) {
  // Only non-generic callees that are directly referenced by the apply can be specialized.
  guard let callee = apply.referencedFunction, !callee.isGeneric else {
    return
  }

  // Argument indices of all metatype arguments that can be rematerialized inside the callee.
  var deadArgIndices = [Int]()
  for (argIdx, argType) in callee.argumentTypes.enumerated() where argType.isRemovableMetatype(in: callee) {
    deadArgIndices.append(argIdx)
  }
  guard !deadArgIndices.isEmpty else {
    return
  }

  let specializedFuncName = context.mangle(withDeadArguments: deadArgIndices, from: callee)

  let specializedCallee: Function
  if let existingSpecialization = context.lookupFunction(name: specializedFuncName) {
    // A previous call site already created this specialization: reuse it.
    specializedCallee = existingSpecialization
  } else if context.loadFunction(function: callee, loadCalleesRecursively: true) {
    specializedCallee = createSpecializedFunction(withName: specializedFuncName,
                                                  withRemovedMetatypeArgumentsOf: apply,
                                                  originalFunction: callee,
                                                  context)
  } else {
    // The callee could not be loaded, so there is no body to specialize.
    return
  }

  context.transform(function: apply.parentFunction) { funcContext in
    replace(apply: apply, to: specializedCallee, funcContext)
  }
}

/// Creates a specialized function by moving the whole function body of `originalFunction` to the new specialized
/// function and calling the specialized function in the original function (which is now a thunk).
///
/// - Parameters:
///   - name: The name of the new specialized function.
///   - apply: A call site of `originalFunction`. It is only used as a template for the forwarding apply
///            that is created in the thunk.
///   - originalFunction: The function to specialize. Its body is moved into the specialized function, and
///                       the function itself becomes the thunk.
///   - context: The module pass context.
/// - Returns: The new specialized function, whose removable metatype arguments have been rematerialized
///            and erased from its entry block.
private func createSpecializedFunction(
  withName name: String,
  withRemovedMetatypeArgumentsOf apply: FullApplySite,
  originalFunction: Function,
  _ context: ModulePassContext
) -> Function {
  // The original parameters minus the removable metatypes. If `self` is such a metatype, the
  // specialized function has no `self` parameter.
  let (aliveParameters, hasSelfParameter) = getAliveParameters(of: originalFunction)

  let specializedFunction = context.createEmptyFunction(
    name: name,
    parameters: aliveParameters,
    hasSelfParameter: hasSelfParameter,
    fromOriginal: originalFunction)

  // Read the location before the body is moved: afterwards `originalFunction` has no instructions.
  // The force-unwrap assumes a non-empty entry block.
  let thunkLoc = originalFunction.entryBlock.instructions.first!.location.autoGenerated

  context.moveFunctionBody(from: originalFunction, to: specializedFunction)
  // originalFunction is now empty and used as the thunk.
  let thunk = originalFunction

  // At this point `specializedFunction.arguments` are still the moved original arguments (including the
  // metatypes), so the thunk gets an entry block with the original argument list.
  context.transform(function: thunk) { funcContext in
    thunk.set(thunkKind: .signatureOptimizedThunk, funcContext)
    createEntryBlock(in: thunk, usingArguments: specializedFunction.arguments, funcContext)
  }

  // Now drop the metatype arguments from the moved body, rematerializing their values.
  context.transform(function: specializedFunction) { funcContext in
    removeMetatypArguments(in: specializedFunction, funcContext)
  }

  // Finally, make the thunk call the specialized function with its non-metatype arguments.
  context.transform(function: thunk) { funcContext in
    createForwardingApply(to: specializedFunction,
                          in: thunk,
                          originalApply: apply,
                          debugLocation: thunkLoc,
                          funcContext)
  }

  return specializedFunction
}

/// Returns the parameters of `originalFunction` that are kept in the specialized function, i.e. all
/// parameters whose argument type is not a removable metatype. Also reports whether the specialized
/// function still has a `self` parameter: this is false if `self` itself is a removed metatype.
private func getAliveParameters(of originalFunction: Function) -> ([ParameterInfo], hasSelfParameter: Bool) {
  let convention = originalFunction.convention
  // Parameter i corresponds to function argument i + (number of indirect results).
  let firstParamArgIdx = convention.indirectSILResultCount
  var keptParams: [ParameterInfo] = []
  var selfIsKept = originalFunction.hasSelfArgument

  for (offset, param) in convention.parameters.enumerated() {
    let argIdx = firstParamArgIdx + offset
    let isRemoved = originalFunction.argumentTypes[argIdx].isRemovableMetatype(in: originalFunction)
    if !isRemoved {
      keptParams.append(param)
      continue
    }
    if selfIsKept && originalFunction.selfArgumentIndex == argIdx {
      selfIsKept = false
    }
  }
  return (keptParams, selfIsKept)
}

/// Appends a new block to `function` and adds one function argument to it for every element of
/// `usingArguments`, using the same type. The new block is intended to be the entry block, so
/// `function` is presumably expected to be empty when this is called.
private func createEntryBlock(
  in function: Function,
  usingArguments: some Sequence<FunctionArgument>,
  _ context: FunctionPassContext
) {
  let newEntryBlock = function.appendNewBlock(context)
  usingArguments.forEach { templateArg in
    _ = newEntryBlock.addFunctionArgument(type: templateArg.type, context)
  }
}

/// Removes all removable metatype arguments from the entry block of `specializedFunction`.
///
/// For each such argument, a thick `metatype` instruction is inserted at the beginning of the entry block.
/// All uses of the argument are redirected to that instruction, and then the argument is erased.
private func removeMetatypArguments(in specializedFunction: Function, _ context: FunctionPassContext) {
  let entryBlock = specializedFunction.entryBlock
  var funcArgIdx = 0
  // The argument count is re-read on every iteration because arguments are erased inside the loop.
  while funcArgIdx < specializedFunction.entryBlock.arguments.count {
    let funcArg = specializedFunction.arguments[funcArgIdx]
    if funcArg.type.isRemovableMetatype(in: specializedFunction) {
      // Rematerialize the metatype value in the entry block.
      let builder = Builder(atBeginOf: entryBlock, context)
      let instanceType = funcArg.type.loweredInstanceTypeOfMetatype(in: specializedFunction)
      let metatype = builder.createMetatype(of: instanceType, representation: .Thick)
      funcArg.uses.replaceAll(with: metatype, context)
      // Erasing shifts the following arguments down, so the index is not advanced here.
      entryBlock.eraseArgument(at: funcArgIdx, context)
    } else {
      funcArgIdx += 1
    }
  }
}

/// Fills the entry block of `thunk` with a call to `specializedFunction`, forwarding all thunk arguments
/// except the removable metatypes, and then returns the result (or rethrows the error).
///
/// - Parameters:
///   - specializedFunction: The function to call.
///   - thunk: The thunk. Its entry block must already exist with the original arguments.
///   - originalApply: Used as a template. Its kind (`ApplyInst` or `TryApplyInst`), substitution map and
///                    specialization info (and, for `ApplyInst`, its non-throwing/non-async flags) are
///                    copied to the forwarding call.
///   - debugLocation: The location used for all instructions created here.
///   - context: The function pass context.
private func createForwardingApply(
  to specializedFunction: Function,
  in thunk: Function,
  originalApply: FullApplySite,
  debugLocation: Location,
  _ context: FunctionPassContext
) {
  // The specialized function takes the same arguments minus the removable metatypes.
  let applyArgs = Array(thunk.arguments.filter { !$0.type.isRemovableMetatype(in: thunk) })

  let builder = Builder(atEndOf: thunk.entryBlock, location: debugLocation, context)
  let callee = builder.createFunctionRef(specializedFunction)

  // Use the original apply as a template to create the forwarding apply.
  switch originalApply {
  case let ai as ApplyInst:
    let newApply = builder.createApply(function: callee,
                                       ai.substitutionMap,
                                       arguments: applyArgs,
                                       isNonThrowing: ai.isNonThrowing,
                                       isNonAsync: ai.isNonAsync,
                                       specializationInfo: ai.specializationInfo)
    builder.createReturn(of: newApply)
  case let tai as TryApplyInst:
    // A try_apply needs successor blocks: a normal block that returns the result and an error block
    // that rethrows the error.
    let normalBlock = thunk.appendNewBlock(context)
    let errorBlock = thunk.appendNewBlock(context)
    builder.createTryApply(function: callee,
                           tai.substitutionMap,
                           arguments: applyArgs,
                           normalBlock: normalBlock,
                           errorBlock: errorBlock,
                           specializationInfo: tai.specializationInfo)
    // The block arguments copy their type and ownership from the successors of the original try_apply.
    let originalArg = tai.normalBlock.arguments[0]
    let returnVal = normalBlock.addArgument(type: originalArg.type, ownership: originalArg.ownership, context)
    let returnBuilder = Builder(atEndOf: normalBlock, location: debugLocation, context)
    returnBuilder.createReturn(of: returnVal)

    let errorArg = tai.errorBlock.arguments[0]
    let errorVal = errorBlock.addArgument(type: errorArg.type, ownership: errorArg.ownership, context)
    let errorBuilder = Builder(atEndOf: errorBlock, location: debugLocation, context)
    errorBuilder.createThrow(of: errorVal)
  default:
    fatalError("unknown full apply instruction \(originalApply)")
  }
}

/// Replaces `apply` with an equivalent call to `specializedCallee` that omits the removable metatype
/// arguments, and then erases the original `apply`.
///
/// For an `ApplyInst`, all uses of the old result are redirected to the new apply. For a `TryApplyInst`,
/// the new try_apply branches to the same normal and error blocks.
private func replace(apply: FullApplySite, to specializedCallee: Function, _ context: FunctionPassContext) {
  let caller = apply.parentFunction
  let builder = Builder(before: apply, context)
  let specializedRef = builder.createFunctionRef(specializedCallee)
  let keptArgs = Array(apply.arguments.filter { !$0.type.isRemovableMetatype(in: caller) })

  if let ai = apply as? ApplyInst {
    let specializedApply = builder.createApply(function: specializedRef,
                                               ai.substitutionMap,
                                               arguments: keptArgs,
                                               isNonThrowing: ai.isNonThrowing,
                                               isNonAsync: ai.isNonAsync,
                                               specializationInfo: ai.specializationInfo)
    ai.uses.replaceAll(with: specializedApply, context)
  } else if let tai = apply as? TryApplyInst {
    builder.createTryApply(function: specializedRef,
                           tai.substitutionMap,
                           arguments: keptArgs,
                           normalBlock: tai.normalBlock,
                           errorBlock: tai.errorBlock,
                           specializationInfo: tai.specializationInfo)
  } else {
    fatalError("unknown full apply instruction \(apply)")
  }
  context.erase(instruction: apply)
}

private extension Type {
  /// Returns true if this type is a thick metatype whose lowered instance type is a struct or an enum.
  ///
  /// Such a metatype value is statically known, so it does not need to be passed as an argument and
  /// can be rematerialized inside `function` instead.
  func isRemovableMetatype(in function: Function) -> Bool {
    guard isMetatype, representationOfMetatype(in: function) == .Thick else {
      return false
    }
    let instanceType = loweredInstanceTypeOfMetatype(in: function)
    return instanceType.isStruct || instanceType.isEnum
  }
}
